import os
import openai
import random
import numpy as np
import json
import jsonlines
import time
from tqdm import tqdm
from rank_bm25 import BM25Okapi
import threading

# The API key is taken from the environment so that no credential is stored in
# the source tree. Export OPENAI_API_KEY before running this script; when the
# variable is unset the value is None and the first API call will fail with an
# authentication error.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
openai.api_key = OPENAI_API_KEY

# Instruction header that opens every prompt sent to the model. It starts and
# ends with a newline; the rest of the prompt is appended directly after it.
start_prompt = (
    "\n"
    "You need to pick up some sentences from a list of captions most related to given cue. Here are some examples."
    "\n"
)


def ask_gpt4(question, thread_id, file_lock, line, unanswered_questions, *,
             model="gpt-4", max_attempts=20, retry_delay=0.2,
             output_path='./gpt4_ans/winogavil/casenum_icl/8/swow/test.jsonl'):
    """Send one prompt to the chat model and append the answer to a JSONL file.

    On success the model's reply is stored under ``line['gpt4']`` (the dict is
    modified in place) and the dict is appended as one JSON line to
    ``output_path``. The function is meant to be used as a thread target, so
    results are reported through side effects rather than a return value.

    Args:
        question: Full prompt text, sent as a single user message.
        thread_id: Caller-chosen identifier; only recorded alongside the
            question when every attempt fails.
        file_lock: Lock held while appending to ``output_path`` so that
            concurrent workers do not interleave their writes.
        line: Source record (a JSON-serialisable dict) the answer is added to.
        unanswered_questions: List that receives ``(question, thread_id, line)``
            when all attempts hit a rate limit or a timeout.
        model: Chat model name passed to the API.
        max_attempts: Number of API calls tried before giving up.
        retry_delay: Seconds to sleep after a rate-limit or timeout error.
        output_path: JSONL file the answered record is appended to. Its parent
            directory is created if it does not exist.

    Raises:
        Exception: For any other ``openai.error.OpenAIError``. The original
            error is attached as ``__cause__``.
    """
    messages = [{"role": "user", "content": question}]

    for _ in range(max_attempts):
        # Only the API call sits inside the try block, so failures while
        # writing the result are never mistaken for API errors.
        try:
            response = openai.ChatCompletion.create(
                model=model,
                max_tokens=1000,
                temperature=1.2,
                messages=messages)
        except (openai.error.RateLimitError, openai.error.Timeout):
            # Rate limit hit or request timed out: wait briefly and try again.
            time.sleep(retry_delay)
            continue
        except openai.error.OpenAIError as err:
            # Any other API error is not retried; keep the cause for debugging.
            raise Exception("Sorry, a problem happened") from err

        line['gpt4'] = response["choices"][0]["message"]["content"]
        record = json.dumps(line) + "\n"
        out_dir = os.path.dirname(output_path)
        with file_lock:
            if out_dir:
                # Avoid losing an answer just because the directory is missing.
                os.makedirs(out_dir, exist_ok=True)
            with open(output_path, 'a') as outfile:
                outfile.write(record)
        return

    # Every attempt was rate limited or timed out: hand the question back.
    unanswered_questions.append((question, thread_id, line))
        

def read_jsonline(sample_file):
    """Turn every record of ``sample_file`` into one worked-example sentence.

    ``sample_file`` must expose an ``iter()`` method yielding dict-like records
    with the keys ``'cue'``, ``'labels'`` and ``'explanations'``. The number
    of non-zero entries in ``'labels'`` is reported as the number of selected
    cases. Returns the sentences as a list, in record order.
    """
    template = "The given cue is %s, the selected number of case is %d, and the captions %s. The corresponding correct labels are %s"

    def render(record):
        # Count of non-zero labels = how many captions were marked as selected.
        selected = np.count_nonzero(record['labels'])
        return template % (record['cue'], selected, str(record['explanations']), record['labels'])

    return [render(record) for record in sample_file.iter()]
        
if __name__ == "__main__":
    # Pool of worked examples; records in the test file refer to them by index.
    # Both readers are opened as context managers so the files get closed.
    with jsonlines.open('./data/winogavil/cb_icl/swow/train.jsonl') as sample_file:
        corpus = read_jsonline(sample_file)
    # Load the test records up front so the progress bar total is their real
    # count instead of a hard-coded number.
    with jsonlines.open('./data/winogavil/casenum_icl/swow/test_top10.jsonl') as dataset:
        records = list(dataset.iter())

    file_lock = threading.Lock()
    threads = []
    unanswered_questions = []
    with tqdm(desc='Process', unit='it', total=len(records)) as pbar:
        # One worker thread per test record; thread ids start at 1.
        for num, line in enumerate(records, start=1):
            # Use the first 8 example indices listed for this record.
            top_index = line['mm_icl'][:8]
            samples_prompt = ''.join(corpus[idx] for idx in top_index)
            captions = line['explanations']
            cue = line['cue']
            labels = line['labels']
            # Number of captions to ask for = number of non-zero labels.
            k = np.count_nonzero(labels)
            Question_part1 = '''\nNow choose the top %d sentences most related to the cue %s from captions: %s. ''' % (k, cue, str(captions))
            Question_part2 = '''Directly return the %d sentences as answer.''' % (k)
            content = f'''{start_prompt}{samples_prompt}{Question_part1}{Question_part2}'''
            thread = threading.Thread(target=ask_gpt4, args=(content, num, file_lock, line, unanswered_questions))
            threads.append(thread)
            thread.start()

        # Advance the bar as workers finish, so it shows completed requests
        # rather than merely started ones.
        for thread in threads:
            thread.join()
            pbar.update()

    if unanswered_questions:
        # Second pass for questions whose attempts were all rate limited or
        # timed out. Failures are collected so they can be reported.
        still_unanswered = []
        retry_threads = [
            threading.Thread(target=ask_gpt4, args=(question, thread_id, file_lock, line, still_unanswered))
            for question, thread_id, line in unanswered_questions
        ]
        for thread in retry_threads:
            thread.start()
        for thread in retry_threads:
            thread.join()

        if still_unanswered:
            failed_ids = sorted(thread_id for _, thread_id, _ in still_unanswered)
            print("%d question(s) still unanswered after the retry pass (thread ids: %s)"
                  % (len(failed_ids), failed_ids))